{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch Implementation of Curiosity-Driven Exploration by Self-Supervised Prediction\n",
    "### Trained to play Super Mario Bros. with and without game-generated explicit rewards.\n",
    "#### Successfully learns to progress through game with just intrinsic (curiosity) rewards.\n",
    "- Paper: \"Curiosity-driven Exploration by Self-supervised Prediction\" Pathak et al 2017\n",
    "\n",
    "This implementation follows the paper almost exactly, however, we use a Deep Q-network as the agent rather than A3C for simplicity. We also do not utilize an LSTM layer in the Q-network for simplicity and it is unnecessary for Super Mario Brothers. \n",
    "\n",
    "You can train this model on a modern laptop for a few thousand iterations (will take 30+ minutes) and can already see interesting results (i.e. the agent will be obviously much better than the random agent, making relatively consistent forward progress, jumping over/on enemies and over obstacles. To match the reference paper's result you will need to train much longer using a GPU.\n",
    "\n",
    "If you like this, please check out our book, [Deep Reinforcement Learning in Action](https://www.manning.com/books/deep-reinforcement-learning-in-action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "from skimage.transform import resize\n",
    "import numpy as np\n",
    "from random import shuffle\n",
    "from collections import deque\n",
    "from IPython import display\n",
    "\n",
    "import gym\n",
    "from nes_py.wrappers import JoypadSpace\n",
    "import gym_super_mario_bros\n",
    "from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT\n",
    "env = gym_super_mario_bros.make('SuperMarioBros-v0')\n",
    "env = JoypadSpace(env, COMPLEX_MOVEMENT)\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test the environment by playing a random agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Random agent. Just for testing\n",
    "done = True\n",
    "for step in range(5000):\n",
    "    if done:\n",
    "        state = env.reset()\n",
    "    state, reward, done, info = env.step(env.action_space.sample())\n",
    "    env.render()\n",
    "#env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def downscale_obs(obs, new_size=(42,42), to_gray=True):\n",
    "    \"\"\"\n",
    "    downscale_obs: rescales RGB image to lower dimensions with option to change to grayscale\n",
    "    \n",
    "    obs: Numpy array or PyTorch Tensor of dimensions Ht x Wt x 3 (channels)\n",
    "    \n",
    "    to_gray: if True, will use max to sum along channel dimension for greatest contrast\n",
    "    \"\"\"\n",
    "    if to_gray:\n",
    "        return resize(obs, new_size, anti_aliasing=True).max(axis=2)\n",
    "    else:\n",
    "        return resize(obs, new_size, anti_aliasing=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the Q-network module\n",
    "\n",
    "- Input: x (Tensor dims: Batch x (3) Channels x (42) Ht x (42) Wt\n",
    "- Output: Batch x 12 (Q values per action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Qnetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Qnetwork, self).__init__()\n",
    "        #in_channels, out_channels, kernel_size, stride=1, padding=0\n",
    "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv2 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.linear1 = nn.Linear(288,100)\n",
    "        self.linear2 = nn.Linear(100,12)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        x = F.normalize(x)\n",
    "        y = F.elu(self.conv1(x))\n",
    "        y = F.elu(self.conv2(y))\n",
    "        y = F.elu(self.conv3(y))\n",
    "        y = F.elu(self.conv4(y))\n",
    "        y = y.flatten(start_dim=2)\n",
    "        y = y.view(y.shape[0], -1, 32)\n",
    "        y = y.flatten(start_dim=1)\n",
    "        y = F.elu(self.linear1(y))\n",
    "        y = self.linear2(y) #size N, 12\n",
    "        return y"
   ]
  },
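  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the `nn.Linear(288,100)` input size above, the cell below applies the standard convolution output-size formula, floor((n + 2*padding - kernel) / stride) + 1, once per convolutional layer, starting from a 42-pixel side. This is a minimal sketch; the helper name `conv_out` is ours and is not part of the model code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_out(n, kernel=3, stride=2, padding=1):\n",
    "    # Output side length of a square convolution (standard formula)\n",
    "    return (n + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "side = 42\n",
    "sizes = [side]\n",
    "for _ in range(4): # four conv layers, each with kernel 3, stride 2, padding 1\n",
    "    side = conv_out(side)\n",
    "    sizes.append(side)\n",
    "print(sizes)            # [42, 21, 11, 6, 3]\n",
    "print(32 * side * side) # 288 = 32 channels x 3 x 3"
   ]
  },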
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the Phi (encoder) network\n",
    "- Input: A state of dimensions Batch x (3) Channels x 42 (Ht) x 42 (Wt)\n",
    "- Output: Encoded (lower-dimensional) state of dimension Batch x 288"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Phi(nn.Module): # (raw state) -> low dim state\n",
    "    def __init__(self):\n",
    "        super(Phi, self).__init__()\n",
    "        #in_channels, out_channels, kernel_size, stride=1, padding=0\n",
    "        self.conv1 = nn.Conv2d(3, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv2 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv3 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "        self.conv4 = nn.Conv2d(32, 32, kernel_size=(3,3), stride=2, padding=1)\n",
    "\n",
    "    def forward(self,x):\n",
    "        x = F.normalize(x)\n",
    "        y = F.elu(self.conv1(x))\n",
    "        y = F.elu(self.conv2(y))\n",
    "        y = F.elu(self.conv3(y))\n",
    "        y = F.elu(self.conv4(y)) #size [1, 32, 3, 3] batch, channels, 3 x 3\n",
    "        y = y.flatten(start_dim=1) #size N, 288\n",
    "        return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Inverse model:  $g(\\phi(S_t), \\phi(S_{t+1})$\n",
    "- Input 1: Encoded state1 $\\phi(\\text{State}_t)$ of dimension Batch x 288\n",
    "- Input 2: Encoded state2 $\\phi(\\text{State}_{t+1})$\n",
    "- Output: Predicted action that was taken to get from $S_t$ to $S_{t+1}$. That is, a softmax over actions, dimensions Batch x 12 (for 12 discrete actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Gnet(nn.Module): #Inverse model: (phi_state1, phi_state2) -> action\n",
    "    def __init__(self):\n",
    "        super(Gnet, self).__init__()\n",
    "        #in_channels, out_channels, kernel_size, stride=1, padding=0\n",
    "        self.linear1 = nn.Linear(576,256)\n",
    "        self.linear2 = nn.Linear(256,12)\n",
    "\n",
    "    def forward(self, state1,state2):\n",
    "        x = torch.cat( (state1, state2) ,dim=1)\n",
    "        y = F.relu(self.linear1(x))\n",
    "        y = self.linear2(y)\n",
    "        y = F.softmax(y,dim=1)\n",
    "        return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Forward Model:  $f(\\phi(s_t), a_t)$\n",
    "- Input 1: Encoded state \\phi(s_t) of dimension Batch x 288\n",
    "- Input 2: (Integer-encoded 0-11 for 12 discrete actions) action of dimension Batch x 12\n",
    "- Output: Predicted encoded next state $\\phi(s_{t+1})$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Fnet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Fnet, self).__init__()\n",
    "        #in_channels, out_channels, kernel_size, stride=1, padding=0\n",
    "        self.linear1 = nn.Linear(300,256)\n",
    "        self.linear2 = nn.Linear(256,288)\n",
    "\n",
    "    def forward(self,state,action):\n",
    "        action_ = torch.zeros(action.shape[0],12)\n",
    "        indices = torch.stack( (torch.arange(action.shape[0]), action.squeeze()), dim=0)\n",
    "        indices = indices.tolist()\n",
    "        action_[indices] = 1.\n",
    "        x = torch.cat( (state,action_) ,dim=1)\n",
    "        y = F.relu(self.linear1(x))\n",
    "        y = self.linear2(y)\n",
    "        return y"
   ]
  },
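  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The forward model concatenates the 288-dimensional encoded state with a 12-dimensional one-hot action vector, which is why its first layer takes 300 inputs. The cell below is a minimal sketch of that input construction on dummy data. It uses `torch.nn.functional.one_hot` for brevity, and the tensor names are illustrative rather than taken from the model code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "phi_state = torch.zeros(4, 288)               # dummy batch of 4 encoded states\n",
    "actions = torch.tensor([[0], [3], [11], [7]]) # integer actions, Batch x 1\n",
    "one_hot = F.one_hot(actions.flatten(), num_classes=12).float() # Batch x 12\n",
    "x = torch.cat((phi_state, one_hot), dim=1)    # forward-model input\n",
    "print(one_hot.argmax(dim=1).tolist()) # [0, 3, 11, 7]\n",
    "print(tuple(x.shape))                 # (4, 300)"
   ]
  },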
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy(qvalues, eps=None): #Epsilon greedy\n",
    "    \"\"\"\n",
    "    policy(qvales, eps=None) takes Q-values and produces an integer representing action\n",
    "    \n",
    "    The function takes a vector of dimension (12,) representing Q-values for each of the 12 discrete actions\n",
    "    and returns an integer. If `eps` is supplied it follows an epsilon-greedy policy with probability `eps`. If `eps`\n",
    "    is not supplied, then a softmax policy is used.\n",
    "    \n",
    "    \"\"\"\n",
    "    if eps is not None:\n",
    "        if torch.rand(1) < eps:\n",
    "            return torch.randint(low=0,high=7,size=(1,))\n",
    "        else:\n",
    "            return torch.argmax(qvalues)\n",
    "    else:\n",
    "        return torch.multinomial(F.softmax(F.normalize(qvalues)), num_samples=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experience replay memory\n",
    "\n",
    "The experience replay memory stores a list of tuples where each tuple is a single experience from an initial state 1, an action taken, the resulting state 2 and reward: $(S_t, a_t, R_{t+1}, S_{t+1})$\n",
    "\n",
    "\n",
    "The `add_memory(state1, action, reward, state2)` method adds a memory to the memory list. The memory list has a fixed maximum length, if it's full, new memories will randomly overwrite old memories.\n",
    "\n",
    "The `get_batch()` method returns a random subset from the memory list for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExperienceReplay:\n",
    "    def __init__(self, N=500, batch_size=100):\n",
    "        self.N = N #total memory size\n",
    "        self.batch_size = batch_size\n",
    "        self.memory = [] #list of tuples of tensors (S_t, a_t, R_{t+1}, S_{t+1})\n",
    "        self.counter = 0\n",
    "        #S_t should be size B x Channel x Ht x Wt. R_t : B x 1\n",
    "        \n",
    "    def add_memory(self, state1, action, reward, state2):\n",
    "        self.counter +=1 \n",
    "        if self.counter % 500 == 0:\n",
    "            self.shuffle_memory()\n",
    "            \n",
    "        if len(self.memory) < self.N:\n",
    "            self.memory.append( (state1, action, reward, state2) )\n",
    "        else:\n",
    "            rand_index = np.random.randint(0,self.N-1)\n",
    "            self.memory[rand_index] = (state1, action, reward, state2) #replace random memory\n",
    "    \n",
    "    def shuffle_memory(self):\n",
    "        shuffle(self.memory)\n",
    "        \n",
    "    def get_batch(self):\n",
    "        if len(self.memory) < self.batch_size:\n",
    "            batch_size = len(self.memory)\n",
    "        else:\n",
    "            batch_size = self.batch_size\n",
    "        if len(self.memory) < 1:\n",
    "            print(\"Error: No data in memory.\")\n",
    "            return None\n",
    "\n",
    "        ind = np.random.choice(np.arange(len(self.memory)),batch_size,replace=False)\n",
    "        batch = [self.memory[i] for i in ind] #batch is a list of tuples\n",
    "        state1_batch = torch.stack([x[0].squeeze(dim=0) for x in batch],dim=0)\n",
    "        action_batch = torch.Tensor([x[1] for x in batch]).long()\n",
    "        reward_batch = torch.Tensor([x[2] for x in batch])\n",
    "        state2_batch = torch.stack([x[3].squeeze(dim=0) for x in batch],dim=0)\n",
    "        return state1_batch, action_batch, reward_batch, state2_batch"
   ]
  },
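  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short usage sketch of the replay memory with dummy tensors is shown below. It assumes the `ExperienceReplay` cell above has already been run; the name `demo_replay` and the small sizes are chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "demo_replay = ExperienceReplay(N=5, batch_size=3) # tiny buffer for demonstration\n",
    "for t in range(8): # more experiences than the buffer can hold\n",
    "    s1 = torch.zeros(1, 3, 42, 42) # dummy state, 1 x 3 x 42 x 42\n",
    "    s2 = torch.ones(1, 3, 42, 42)  # dummy next state\n",
    "    demo_replay.add_memory(s1, t % 12, 0.0, s2)\n",
    "s1_b, a_b, r_b, s2_b = demo_replay.get_batch()\n",
    "print(len(demo_replay.memory)) # 5 (never grows past N)\n",
    "print(tuple(s1_b.shape))       # (3, 3, 42, 42)\n",
    "print(tuple(a_b.shape))        # (3,)"
   ]
  },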
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define basic hyperparameters. See reference paper for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    'batch_size':150,\n",
    "    'beta':0.2,\n",
    "    'lambda':0.1,\n",
    "    'eta': 1.0,\n",
    "    'gamma':0.2,\n",
    "    'max_episode_len':100,\n",
    "    'min_progress':15,\n",
    "    'action_repeats':6,\n",
    "    'frames_per_state':3\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Instantiate the 4 modules and the experience replay buffer. Setup the optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "replay = ExperienceReplay(N=1000, batch_size=params['batch_size'])\n",
    "Qmodel = Qnetwork()\n",
    "encoder = Phi()\n",
    "forward_model = Fnet()\n",
    "inverse_model = Gnet()\n",
    "forward_loss = nn.MSELoss(reduction='none')#torch.nn.PairwiseDistance()#\n",
    "inverse_loss = nn.CrossEntropyLoss(reduction='none')\n",
    "qloss = nn.MSELoss()\n",
    "# We can add the model parameters from each model to a list and pass that to a single optimizer\n",
    "all_model_params = list(Qmodel.parameters()) + list(encoder.parameters()) \n",
    "all_model_params += list(forward_model.parameters()) + list(inverse_model.parameters())\n",
    "opt = optim.Adam(lr=0.001, params=all_model_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ICM(state1, action, state2, forward_scale=1., inverse_scale=1e4): #action is an integer [0:11]\n",
    "    \"\"\"\n",
    "    Intrinsic Curiosity Module (ICM): Calculates prediction error for forward and inverse dynamics\n",
    "    \n",
    "    The ICM takes a state1, the action that was taken, and the resulting state2 as inputs\n",
    "    (from experience replay memory) and uses the forward and inverse models to calculate the prediction error\n",
    "    and train the encoder to only pay attention to details in the environment that are controll-able (i.e. it should\n",
    "    learn to ignore useless stochasticity in the environment and not encode that).\n",
    "    \n",
    "    \"\"\"\n",
    "    state1_hat = encoder(state1)\n",
    "    state2_hat = encoder(state2)\n",
    "    #Forward model prediction error\n",
    "    state2_hat_pred = forward_model(state1_hat.detach(), action.detach())\n",
    "    forward_pred_err = forward_scale * forward_loss(state2_hat_pred, \\\n",
    "                        state2_hat.detach()).sum(dim=1).unsqueeze(dim=1)\n",
    "    #Inverse model prediction error\n",
    "    pred_action = inverse_model(state1_hat, state2_hat) #returns softmax over actions\n",
    "    inverse_pred_err = inverse_scale * inverse_loss(pred_action, \\\n",
    "                                        action.detach().flatten()).unsqueeze(dim=1)\n",
    "    return forward_pred_err, inverse_pred_err"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(q_loss, inverse_loss, forward_loss):\n",
    "    \"\"\"\n",
    "    Overall loss function to optimize for all 4 modules\n",
    "    \n",
    "    Loss function based on calculation in paper\n",
    "    \"\"\"\n",
    "    loss_ = (1 - params['beta']) * inverse_loss\n",
    "    loss_ += params['beta'] * forward_loss\n",
    "    loss_ = loss_.sum() / loss_.flatten().shape[0]\n",
    "    loss = loss_ + params['lambda'] * q_loss\n",
    "    return loss"
   ]
  },
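  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the quantity computed by `loss_fn` above can be written as follows. This is our transcription of the code rather than a formula quoted from the paper: $L_Q$ is the Q-loss, $L_I^{(i)}$ and $L_F^{(i)}$ are the per-sample inverse and forward errors, $N$ is the batch size, and $\\beta$ and $\\lambda$ are `params['beta']` and `params['lambda']`.\n",
    "\n",
    "$$\\mathcal{L} = \\lambda L_Q + \\frac{1}{N}\\sum_{i=1}^{N}\\left[(1-\\beta)\\,L_I^{(i)} + \\beta\\,L_F^{(i)}\\right]$$\n",
    "\n",
    "With $\\beta = 0.2$, the inverse-model error carries four times the weight of the forward-model error."
   ]
  },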
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_state(state):\n",
    "    \"\"\"\n",
    "    First downscale state, convert to grayscale, convert to torch tensor and add batch dimension\n",
    "    \"\"\"\n",
    "    return torch.from_numpy(downscale_obs(state, to_gray=True)).float().unsqueeze(dim=0)\n",
    "\n",
    "def prepare_multi_state(state1, state2):\n",
    "    \"\"\"\n",
    "    Prepare a 3 channel state (for use in inference not training).\n",
    "    \n",
    "    The Q-model and encoder/Phi model expect the input state to have 3 channels. Following the reference paper,\n",
    "    these models are fed 3 consecutive state frames to give the model's access to motion information \n",
    "    (i.e. velocity information rather than just positional information)\n",
    "    \"\"\"\n",
    "    #prev is 1x3x42x42\n",
    "    state1 = state1.clone()\n",
    "    tmp = torch.from_numpy(downscale_obs(state2, to_gray=True)).float()\n",
    "    #shift data along tensor to accomodate newest observation (we could have used deque w/ maxlen 3)\n",
    "    state1[0][0] = state1[0][1]\n",
    "    state1[0][1] = state1[0][2]\n",
    "    state1[0][2] = tmp #replace last frame\n",
    "    return state1\n",
    "\n",
    "def prepare_initial_state(state,N=3):\n",
    "    \"\"\"\n",
    "    Prepares the initial state which is just a tensor of 1 (Batch) x 3 x 42 x 42\n",
    "    \n",
    "    The channel dimension is just a copy of the input state 3 times\n",
    "    \n",
    "    \"\"\"\n",
    "    #state should be 42x42 array\n",
    "    state_ = torch.from_numpy(downscale_obs(state, to_gray=True)).float()\n",
    "    tmp = state_.repeat((N,1,1)) #now 3x42x42\n",
    "    return tmp.unsqueeze(dim=0) #now 1x3x42x42"
   ]
  },
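  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The frame-stacking idea behind `prepare_multi_state` can also be expressed with a `deque`, which is what the training loop does. The sketch below uses constant-valued tensors as stand-ins for real 42 x 42 grayscale frames so that the ordering is easy to see; the variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "import torch\n",
    "\n",
    "frames = deque(maxlen=3) # keeps only the 3 most recent frames\n",
    "for t in range(5):\n",
    "    frame = torch.full((1, 42, 42), float(t)) # stand-in for a 1 x 42 x 42 prepared frame\n",
    "    frames.append(frame)\n",
    "state = torch.stack(list(frames), dim=1) # 1 x 3 x 42 x 42\n",
    "print(tuple(state.shape))         # (1, 3, 42, 42)\n",
    "print(state[0, :, 0, 0].tolist()) # [2.0, 3.0, 4.0] (oldest to newest frame)"
   ]
  },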
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reset_env():\n",
    "    \"\"\"\n",
    "    Reset the environment and return a new initial state\n",
    "    \"\"\"\n",
    "    env.reset()\n",
    "    state1 = prepare_initial_state(env.render('rgb_array'))\n",
    "    return state1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Part of main training loop. \n",
    "This code extracts a minibatch from the experience replay memory and runs the 4 modules forward and calculates the prediction errors for each, returning them as output.\n",
    "\n",
    "If `use_explicit` is set to `True` then the reward will include the game-generated explicit reward. If set to `False` then the agent will learn only based on the instrinsic (curiosity) based prediction-error reward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def minibatch_train(use_extrinsic=True):\n",
    "    state1_batch, action_batch, reward_batch, state2_batch = replay.get_batch() \n",
    "    action_batch = action_batch.view(action_batch.shape[0],1)\n",
    "    reward_batch = reward_batch.view(reward_batch.shape[0],1)\n",
    "    #replay.get_batch returns tuple (state1, action, reward, state2) where each tensor has batch dimension\n",
    "    forward_pred_err, inverse_pred_err = ICM(state1_batch, action_batch, state2_batch) #internal curiosity module\n",
    "    i_reward = (1. / params['eta']) * forward_pred_err\n",
    "    reward = i_reward.detach()\n",
    "    if use_extrinsic:\n",
    "        reward += reward_batch \n",
    "    qvals = Qmodel(state2_batch)\n",
    "    reward += params['gamma'] * torch.max(qvals)\n",
    "    reward_pred = Qmodel(state1_batch)\n",
    "    reward_target = reward_pred.clone()\n",
    "    indices = torch.stack( (torch.arange(action_batch.shape[0]), action_batch.squeeze()), dim=0)\n",
    "    indices = indices.tolist()\n",
    "    reward_target[indices] = reward.squeeze()\n",
    "    q_loss = 1e5 * qloss(F.normalize(reward_pred), F.normalize(reward_target.detach()))\n",
    "    return forward_pred_err, inverse_pred_err, q_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main Training Loop\n",
    "\n",
    "Training details to note:\n",
    "- Training starts with softmax action policy, and then after 1000 steps (or whatever you set it to) it will switch to \n",
    "epsilon greedy policy. This empirically seems to help with exploration in the beginning and then epsilon-greedy helps increase the explotation. You should try starting with a softmax policy and increase the temperature slowly during training to make a continuous/smooth version of this approach.\n",
    "\n",
    "- We use a deque (a list-like data structure where we can specify a maximum length, append items and old items will\n",
    "automatically get pushed out once the max len is hit) to store the last 3 frames and then package it into a tensor for use as the state.\n",
    "\n",
    "- We keep track of the `last_x_pos` from the prior 50 frames. If the agent hasn't made significant forward progress (x_now - last_x_pos) > 10 then we assume the agent is stuck and we reset the environment.\n",
    "\n",
    "- Following the reference paper, we use deterministic sticky-actions such that if the policy says do action 0, then we repeat that action 6 times (only during training, during inference we just take the action once). This helps since each action is a very small step in any direction so by compounding them in training the agent can learn faster what the actions are doing.\n",
    "\n",
    "- If you train with intrinsic reward only, this implementation is not very stable. Sometimes the agent just repeatedly does something stupid and will not learn, and other times it will do very well."
   ]
  },
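  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The smooth softmax-to-greedy transition mentioned in the first note could be sketched as a temperature-controlled softmax policy. This is an illustration only and is not used by the training loop below; the function name `softmax_policy` and the temperature values are assumptions. Lower temperatures concentrate probability on the highest Q-value, so annealing `tau` downward moves the policy from exploration toward exploitation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def softmax_policy(qvalues, tau=1.0):\n",
    "    # qvalues: 1 x num_actions. Lower tau -> closer to greedy (argmax).\n",
    "    probs = F.softmax(qvalues / tau, dim=1)\n",
    "    action = int(torch.multinomial(probs, num_samples=1))\n",
    "    return probs, action\n",
    "\n",
    "q = torch.tensor([[1.0, 2.0, 3.0]]) # toy Q-values for 3 actions\n",
    "for tau in (10.0, 1.0, 0.1):\n",
    "    probs, action = softmax_policy(q, tau)\n",
    "    # probabilities sharpen around the last action as tau shrinks\n",
    "    print(tau, [round(p, 3) for p in probs[0].tolist()], action)"
   ]
  },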
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/brandonbrown/anaconda3/envs/deeprl/lib/python3.6/site-packages/skimage/transform/_warps.py:105: UserWarning: The default mode, 'constant', will be changed to 'reflect' in skimage 0.15.\n",
      "  warn(\"The default mode, 'constant', will be changed to 'reflect' in \"\n",
      "/Users/brandonbrown/anaconda3/envs/deeprl/lib/python3.6/site-packages/ipykernel/__main__.py:16: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 250, Loss: 5531.4599609375\n",
      "Forward loss: 3.2666757106781006 \n",
      " Inverse loss: 24682.59375 \n",
      " Qloss: 5923.2978515625\n",
      "{'coins': 1, 'flag_get': False, 'life': 2, 'score': 1350, 'stage': 1, 'status': 'tall', 'time': 328, 'world': 1, 'x_pos': 281}\n",
      "Episode over.\n",
      "Epoch 500, Loss: 4975.23681640625\n",
      "Forward loss: 0.11121562123298645 \n",
      " Inverse loss: 24741.8515625 \n",
      " Qloss: 267.779541015625\n",
      "{'coins': 0, 'flag_get': False, 'life': 2, 'score': 0, 'stage': 1, 'status': 'small', 'time': 370, 'world': 1, 'x_pos': 222}\n",
      "Episode over.\n",
      "Epoch 750, Loss: 4957.77734375\n",
      "Forward loss: 0.2258693426847458 \n",
      " Inverse loss: 24692.28515625 \n",
      " Qloss: 191.37889099121094\n",
      "{'coins': 0, 'flag_get': False, 'life': 2, 'score': 0, 'stage': 1, 'status': 'small', 'time': 385, 'world': 1, 'x_pos': 50}\n",
      "Epoch 1000, Loss: 5014.279296875\n",
      "Forward loss: 0.06708843261003494 \n",
      " Inverse loss: 24925.478515625 \n",
      " Qloss: 291.2850036621094\n",
      "{'coins': 1, 'flag_get': False, 'life': 2, 'score': 1400, 'stage': 1, 'status': 'tall', 'time': 313, 'world': 1, 'x_pos': 425}\n"
     ]
     }
   ],
   "source": [
    "epochs = 2500\n",
    "env.reset()\n",
    "state1 = prepare_initial_state(env.render('rgb_array'))\n",
    "eps=0.15\n",
    "losses = []\n",
    "ep_lengths = []\n",
    "episode_length = 0\n",
    "switch_to_eps_greedy = 1000\n",
    "state_deque = deque(maxlen=params['frames_per_state'])\n",
    "e_reward = 0.\n",
    "last_x_pos = env.env.env._x_position\n",
    "for i in range(epochs):\n",
    "    opt.zero_grad()\n",
    "    episode_length += 1\n",
    "    q_val_pred = Qmodel(state1)\n",
    "    if i > switch_to_eps_greedy:\n",
    "        action = int(policy(q_val_pred,eps))\n",
    "    else:\n",
    "        action = int(policy(q_val_pred))\n",
    "    for j in range(params['action_repeats']):\n",
    "        state2, e_reward_, done, info = env.step(action)\n",
    "        if done:\n",
    "            state1 = reset_env()\n",
    "            break\n",
    "        e_reward += e_reward_\n",
    "        state_deque.append(prepare_state(state2))\n",
    "    state2 = torch.stack(list(state_deque),dim=1)\n",
    "    replay.add_memory(state1, action, e_reward, state2)\n",
    "    e_reward = 0\n",
    "    if i % params['max_episode_len'] == 0 and i != 0:\n",
    "        if (info['x_pos'] - last_x_pos) < params['min_progress']:\n",
    "            done = True\n",
    "        else:\n",
    "            last_x_pos = info['x_pos']\n",
    "    if done:\n",
    "        print(\"Episode over.\")\n",
    "        ep_lengths.append(info['x_pos'])\n",
    "        state1 = reset_env()\n",
    "        last_x_pos = env.env.env._x_position\n",
    "        episode_length = 0\n",
    "    else:\n",
    "        state1 = state2\n",
    "    #Enter mini-batch training\n",
    "    if len(replay.memory) < params['batch_size']:\n",
    "        continue\n",
    "    \n",
    "    forward_pred_err, inverse_pred_err, q_loss = minibatch_train(use_extrinsic=False)\n",
    "    loss = loss_fn(q_loss, forward_pred_err, inverse_pred_err)\n",
    "    loss_list = (q_loss.mean(), forward_pred_err.flatten().mean(), inverse_pred_err.flatten().mean(), episode_length)\n",
    "    if i % 250 == 0:\n",
    "        print(\"Epoch {}, Loss: {}\".format(i,loss))\n",
    "        print(\"Forward loss: {} \\n Inverse loss: {} \\n Qloss: {}\".format(\\\n",
    "                             forward_pred_err.mean(),inverse_pred_err.mean(),q_loss.mean()))\n",
    "        print(info)\n",
    "    losses.append(loss_list)\n",
    "    loss.backward()\n",
    "    opt.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot losses for each module\n",
    "\n",
    "Loss plots will look much cleaner if you train using explicit rewards too.\n",
    "Forward loss will generally decrease steadily. Q-learning will decrease but more erratically. Inverse model loss decreases rapidly initially and then plateaus. Note that the encoder/Phi model is trained via the inverse model (both are trained together), it does not have it's own loss.\n",
    "\n",
    "Note! These are log-transformed plots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "losses_ = np.array(losses)\n",
    "ep_lengths_ = np.array(ep_lengths)\n",
    "plt.figure(figsize=(8,6))\n",
    "plt.plot(np.log(losses_[:,0]),label='Q loss')\n",
    "plt.plot(np.log(losses_[:,1]),label='Forward loss')\n",
    "plt.plot(np.log(losses_[:,2]),label='Inverse loss')\n",
    "#plt.plot(ep_lengths_, label='Episode Length')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/brandonbrown/anaconda3/envs/deeprl/lib/python3.6/site-packages/skimage/transform/_warps.py:105: UserWarning: The default mode, 'constant', will be changed to 'reflect' in skimage 0.15.\n",
      "  warn(\"The default mode, 'constant', will be changed to 'reflect' in \"\n"
     ]
    }
   ],
   "source": [
    "#Test model\n",
    "eps=0.15\n",
    "done = True\n",
    "state_deque = deque(maxlen=params['frames_per_state'])\n",
    "for step in range(5000):\n",
    "    if done:\n",
    "        env.reset()\n",
    "        state1 = prepare_initial_state(env.render('rgb_array'))\n",
    "    q_val_pred = Qmodel(state1)\n",
    "    action = int(policy(q_val_pred,eps))\n",
    "    state2, reward, done, info = env.step(action)\n",
    "    state2 = prepare_multi_state(state1,state2)\n",
    "    state1=state2\n",
    "    env.render()\n",
    "#env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Miscellaneous / Unused"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(q,tau=1.4): #q is vector\n",
    "    q = F.normalize(q)\n",
    "    return torch.exp(q/tau) / torch.sum(torch.exp(q/tau))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_encoder(from_replay=False):\n",
    "    \"\"\"\n",
    "    Test's the encoder's ability to disentangle similar states.\n",
    "    \n",
    "    If the encoder is being properly trained, it should learn to encode similar states such that their \n",
    "    euclidian distance is relatively large, so that the the forward network and inverse network can make\n",
    "    better predictions. You'll notice that if you run the function before or early during training \n",
    "    (make sure from_replay is false since the replay will be empty before training) the euclidian distance\n",
    "    between two states will be small, but during training the encoder will learn to disentangle these and the \n",
    "    distance will increase.\n",
    "    \n",
    "    `from_replay=True` will test 2 consecutive states from the replay memory otherwise will just reset environment\n",
    "    and use initial two states after taking action\n",
    "    \"\"\"\n",
    "    if from_replay:\n",
    "        assert len(replay.memory) > 0\n",
    "        s1, a, r, s2 = replay.memory[np.random.randint(len(replay.memory))]\n",
    "    else:\n",
    "        env.reset()\n",
    "        s1 = prepare_initial_state(env.render('rgb_array'))\n",
    "        env.reset()\n",
    "        env.step(3)\n",
    "        s2 = prepare_multi_state(s1, env.render('rgb_array'))\n",
    "        \n",
    "    return nn.MSELoss(reduction='mean')(encoder(s1),encoder(s2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_encoder(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Code to save or load model parameters after training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Save model parameters\n",
    "torch.save(Qmodel.state_dict(),'Qmodel_')\n",
    "torch.save(encoder.state_dict(),'encoder_')\n",
    "torch.save(forward_model.state_dict(),'Fnet_')\n",
    "torch.save(inverse_model.state_dict(),'Gnet_')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load model parameters from file\n",
    "model.load_state_dict(torch.load('Qmodel_'))\n",
    "model.load_state_dict(torch.load('encoder_'))\n",
    "model.load_state_dict(torch.load('Fnet_'))\n",
    "model.load_state_dict(torch.load('Gnet_'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# References"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- \"Curiosity-driven Exploration by Self-supervised Prediction\" Pathak et al 2017"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:deeprl]",
   "language": "python",
   "name": "conda-env-deeprl-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
